{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Kubeflow Pipelines - Retail Product Stockouts Prediction using AutoML Tables\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration\n",
    "\n",
    "# GCP project ID handed to the AutoML pipeline components.\n",
    "PROJECT_ID = \"<project-id-123456>\"\n",
    "\n",
    "# AutoML Tables currently supports only the \"us-central1\" region.\n",
    "COMPUTE_REGION = \"us-central1\"\n",
    "\n",
    "# Where batch prediction writes its results. Due to a limitation of the batch\n",
    "# prediction service, the bucket must be Regional (not multi-regional) and\n",
    "# located in us-central1.\n",
    "batch_predict_gcs_output_uri_prefix = \"gs://<gcs-bucket-regional-us-central1>/<subpath>/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AutoML Tables components\n",
    "\n",
    "from kfp.components import load_component_from_url\n",
    "\n",
    "# All components come from one pinned kubeflow/pipelines commit; the revision is\n",
    "# spelled out once so the five components cannot drift apart.\n",
    "_COMPONENTS_COMMIT = 'b3179d86b239a08bf4884b50dbf3a9151da96d66'\n",
    "_COMPONENTS_ROOT = 'https://raw.githubusercontent.com/kubeflow/pipelines/' + _COMPONENTS_COMMIT + '/components/gcp/automl/'\n",
    "\n",
    "\n",
    "def _load_automl_component(directory):\n",
    "    \"\"\"Load the component.yaml stored under `directory` of the pinned AutoML components.\"\"\"\n",
    "    return load_component_from_url(_COMPONENTS_ROOT + directory + '/component.yaml')\n",
    "\n",
    "\n",
    "automl_create_dataset_for_tables_op = _load_automl_component('create_dataset_for_tables')\n",
    "automl_import_data_from_bigquery_source_op = _load_automl_component('import_data_from_bigquery')\n",
    "automl_create_model_for_tables_op = _load_automl_component('create_model_for_tables')\n",
    "automl_prediction_service_batch_predict_op = _load_automl_component('prediction_service_batch_predict')\n",
    "automl_split_dataset_table_column_names_op = _load_automl_component('split_dataset_table_column_names')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the pipeline\n",
    "import kfp\n",
    "\n",
    "def retail_product_stockout_prediction_pipeline(\n",
    "    gcp_project_id: str,\n",
    "    gcp_region: str,\n",
    "    batch_predict_gcs_output_uri_prefix: str,\n",
    "    dataset_bq_input_uri: str = 'bq://product-stockout.product_stockout.stockout',\n",
    "    dataset_display_name: str = 'stockout_data',\n",
    "    target_column_name: str = 'Stockout',\n",
    "    model_display_name: str = 'stockout_model',\n",
    "    batch_predict_bq_input_uri: str = 'bq://product-stockout.product_stockout.batch_prediction_inputs',\n",
    "    train_budget_milli_node_hours: 'Integer' = 1000,\n",
    "):\n",
    "    \"\"\"Train an AutoML Tables stockout model and run batch prediction with it.\n",
    "\n",
    "    The pipeline creates a dataset, imports a BigQuery table into it, splits the\n",
    "    table columns into target and features, trains a model, and finally runs\n",
    "    batch prediction with the trained model.\n",
    "\n",
    "    Args:\n",
    "        gcp_project_id: GCP project passed to the dataset and model creation steps.\n",
    "        gcp_region: GCP region passed to the dataset and model creation steps.\n",
    "        batch_predict_gcs_output_uri_prefix: GCS URI prefix passed to the batch\n",
    "            prediction step as its output location.\n",
    "        dataset_bq_input_uri: BigQuery URI of the table imported into the dataset.\n",
    "        dataset_display_name: Display name of the created dataset.\n",
    "        target_column_name: Name of the column used as the prediction target.\n",
    "        model_display_name: Display name of the created model.\n",
    "        batch_predict_bq_input_uri: BigQuery URI of the batch prediction inputs.\n",
    "        train_budget_milli_node_hours: Training budget passed to the model\n",
    "            creation step.\n",
    "    \"\"\"\n",
    "    # Step 1: create the dataset that will receive the training table.\n",
    "    dataset_op = automl_create_dataset_for_tables_op(\n",
    "        gcp_project_id=gcp_project_id,\n",
    "        gcp_region=gcp_region,\n",
    "        display_name=dataset_display_name,\n",
    "    )\n",
    "\n",
    "    # Step 2: import the BigQuery table into that dataset.\n",
    "    import_op = automl_import_data_from_bigquery_source_op(\n",
    "        dataset_path=dataset_op.outputs['dataset_path'],\n",
    "        input_uri=dataset_bq_input_uri,\n",
    "    )\n",
    "\n",
    "    # Step 3: split the columns of the first table into target and features.\n",
    "    columns_op = automl_split_dataset_table_column_names_op(\n",
    "        dataset_path=import_op.outputs['dataset_path'],\n",
    "        table_index=0,\n",
    "        target_column_name=target_column_name,\n",
    "    )\n",
    "\n",
    "    # Step 4: train the model. Passing input_feature_column_paths=None instead\n",
    "    # would make the component use all non-target columns.\n",
    "    training_op = automl_create_model_for_tables_op(\n",
    "        gcp_project_id=gcp_project_id,\n",
    "        gcp_region=gcp_region,\n",
    "        display_name=model_display_name,\n",
    "        dataset_id=dataset_op.outputs['dataset_id'],\n",
    "        target_column_path=columns_op.outputs['target_column_path'],\n",
    "        input_feature_column_paths=columns_op.outputs['feature_column_paths'],\n",
    "        optimization_objective='MAXIMIZE_AU_PRC',\n",
    "        train_budget_milli_node_hours=train_budget_milli_node_hours,\n",
    "    )\n",
    "    # Explicit ordering: training must not start before the data import is done.\n",
    "    training_op.after(import_op)\n",
    "\n",
    "    # Step 5: batch prediction with the trained model.\n",
    "    automl_prediction_service_batch_predict_op(\n",
    "        model_path=training_op.outputs['model_path'],\n",
    "        bq_input_uri=batch_predict_bq_input_uri,\n",
    "        gcs_output_uri_prefix=batch_predict_gcs_output_uri_prefix,\n",
    "    )\n",
    "\n",
    "    # The pipeline must be able to authenticate to GCP.\n",
    "    # See [Authenticating Pipelines to GCP](https://www.kubeflow.org/docs/gke/authentication-pipelines/) for details.\n",
    "    #\n",
    "    # For example, uncomment the following lines to use GSA keys:\n",
    "    # from kfp.gcp import use_gcp_secret\n",
    "    # kfp.dsl.get_pipeline_conf().add_op_transformer(use_gcp_secret('user-gcp-sa'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the pipeline\n",
    "\n",
    "# Submit the pipeline function as a run on the cluster. The two BigQuery URIs\n",
    "# repeat the pipeline defaults so they are easy to spot and replace with your\n",
    "# own tables.\n",
    "kfp.run_pipeline_func_on_cluster(\n",
    "    retail_product_stockout_prediction_pipeline,\n",
    "    arguments=dict(\n",
    "        gcp_project_id=PROJECT_ID,\n",
    "        gcp_region=COMPUTE_REGION,\n",
    "        dataset_bq_input_uri='bq://product-stockout.product_stockout.stockout',\n",
    "        batch_predict_bq_input_uri='bq://product-stockout.product_stockout.batch_prediction_inputs',\n",
    "        batch_predict_gcs_output_uri_prefix=batch_predict_gcs_output_uri_prefix,\n",
    "    )\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
